# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import AliasChoices
from pydantic import ConfigDict
from pydantic import Field
from pydantic import computed_field

from nat.builder.builder import Builder
from nat.builder.llm import LLMProviderInfo
from nat.cli.register_workflow import register_llm_provider
from nat.data_models.common import OptionalSecretStr
from nat.data_models.llm import LLMBaseConfig
from nat.data_models.optimizable import OptimizableField
from nat.data_models.optimizable import SearchSpace
from nat.data_models.retry_mixin import RetryMixin
from nat.data_models.thinking_mixin import ThinkingMixin


class AzureOpenAIModelConfig(
        LLMBaseConfig,
        RetryMixin,
        ThinkingMixin,
        name="azure_openai",
):
    """An Azure OpenAI LLM provider to be used with an LLM client."""

    model_config = ConfigDict(protected_namespaces=(), extra="allow")

    api_key: OptionalSecretStr = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
    api_version: str = Field(default="2025-04-01-preview", description="Azure OpenAI API version.")
    azure_endpoint: str | None = Field(validation_alias=AliasChoices("azure_endpoint", "base_url"),
                                       serialization_alias="azure_endpoint",
                                       default=None,
                                       description="Base URL for the hosted model.")
    azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
                                  serialization_alias="azure_deployment",
                                  description="The Azure OpenAI hosted model/deployment name.")
    seed: int | None = Field(default=None, description="Random seed to set for generation.")
    temperature: float | None = OptimizableField(
        default=None,
        ge=0.0,
        description="Sampling temperature to control randomness in the output.",
        space=SearchSpace(high=0.9, low=0.1, step=0.2))
    top_p: float | None = OptimizableField(default=None,
                                           ge=0.0,
                                           le=1.0,
                                           description="Top-p for distribution sampling.",
                                           space=SearchSpace(high=1.0, low=0.5, step=0.1))

    @computed_field
    @property
    def model_name(self) -> str:
        """
        Returns the model name for compatibility with other parts of the code base which expect a model_name attribute.
        """
        return self.azure_deployment


@register_llm_provider(config_type=AzureOpenAIModelConfig)
async def azure_openai_llm(config: AzureOpenAIModelConfig, _builder: Builder):

    yield LLMProviderInfo(config=config, description="An Azure OpenAI model for use with an LLM client.")
